/*
  Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

/// \file
/// Contains the cell_tracker class.

// (This file is included automatically by grid.h)

#include <iterator>
#include "blitz/numinquire.h"
#include "ray.h"
#include "hilbert.h"
#include "vecops.h"
#include "mcrx-debug.h"

#ifdef HAVE_BOOST_SERIALIZATION
#include <boost/serialization/serialization.hpp>
#endif

// Forward declarations of the classes defined in this file, so that
// they can be defined below using their qualified mcrx:: names.
namespace mcrx {
  class c_ordering;
  class hilbert_ordering;
  template <typename> class cell_tracker;
};

/** Ordering policy for plain c-order, in which the z-axis (index 2)
    has the lowest stride. The class keeps no state: push(), pop() and
    reset() do nothing, and the conversion between an octant number
    and its three position bits is a fixed bit shuffle. */
class mcrx::c_ordering {
public:

  /// Nothing to track when descending; present to satisfy the
  /// ordering interface.
  void push(uint8_t octant) {}

  /// Nothing to track when ascending.
  void pop() {}

  /// Nothing to reset.
  void reset() {}

  /** Splits an octant number into its position bits. Bit 0 of the
      octant goes to component 2, bit 1 to component 1 and bit 2 to
      component 0. */
  blitz::TinyVector<uint8_t, 3> pbits(uint8_t octant) {
    blitz::TinyVector<uint8_t, 3> bits;
    bits[0] = (octant & '\x04') >> 2;
    bits[1] = (octant & '\x02') >> 1;
    bits[2] = octant & '\x01';
    return bits;
  }

  /** Inverse of pbits(): packs component 0 into bit 2, component 1
      into bit 1 and component 2 into bit 0 of the returned octant. */
  uint8_t octant(blitz::TinyVector<uint8_t, 3> pbits) {
    const int high = pbits[0] << 2;
    const int middle = pbits[1] << 1;
    return high | middle | pbits[2];
  }
};

/** Ordering policy for hilbert order. It keeps a stack of
    hilbert_state objects, one per level of the current descent, and
    uses the state on top of the stack to translate between positions
    and octants. */
class mcrx::hilbert_ordering {

  /// The icpc compiler generated an unaligned access exception
  /// without manually specifying the alignment.
  ///
  /// The stack holds one state per level 0..maxlevel_ inclusive,
  /// because push() accepts n_ up to maxlevel_-1 and then writes
  /// entry n_+1. It therefore needs maxlevel_+1 entries.
  BZ_ALIGN_VARIABLE(hilbert::hilbert_state,state_[hilbert::cell_code::maxlevel_+1],16)

  /// Index of the current (top) entry in state_.
  uint8_t n_;

public:
  /// Starts at level 0 with a default-constructed hilbert_state.
  hilbert_ordering() : n_(0) { state_[0]=hilbert::hilbert_state(); }

  /** Descends one level: copies the current state to the next stack
      slot and calls descend(octant) on the copy. */
  void push(uint8_t octant) {
    assert(n_<hilbert::cell_code::maxlevel_);
    ++n_;
    state_[n_] = state_[n_-1];
    state_[n_].descend(octant);
  }

  /// Ascends one level by discarding the top state.
  void pop() {
    assert(n_>0);
    --n_;
  }

  /// Returns to level 0. state_[0] is never modified after construction.
  void reset() { n_=0; }

  /// Position bits of the octant according to the current state.
  blitz::TinyVector<uint8_t, 3> pbits(uint8_t octant) {
    return state_[n_].octant2pos(octant); }

  /// Octant corresponding to the position bits according to the
  /// current state.
  uint8_t octant(blitz::TinyVector<uint8_t, 3> pbits) {
    return state_[n_].determine_octant(pbits); }
};


/** This class implements the octree traversal using a stack. It is
    effectively a cell iterator that also knows how to do the
    intersection tests.

    The current state of the traversal is determined by the
    cell_code. The cell_ array just keeps the pointers corresponding
    to the nodes on the path from the root to the current cell.

    The "end" (invalid) state is encoded as level 0xff, see set_end()
    and at_end().
 */
template<typename T> 
class mcrx::cell_tracker : 
  public std::iterator<std::forward_iterator_tag, 
                       typename mcrx::grid_cell<T> >  {
public:
  typedef grid_cell<T> T_cell;
  typedef adaptive_grid<T> T_topgrid;
  typedef hilbert::qpoint T_qpoint;
  typedef hilbert::cell_code T_code;
  //typedef c_ordering T_ordering;
  typedef hilbert_ordering T_ordering;

private:
  static const uint8_t maxlevel_ = hilbert::cell_code::maxlevel_;

  /// The current task, i.e. the value of mpi_rank() at construction.
  int task_;

  /// The code for the current cell.
  T_code code_;

  /// The position of the current cell.
  T_qpoint pos_;

  /// Pointer to the top-level grid
  T_topgrid* g_;

  /// The current level of the traversal (0xff means "at end").
  uint8_t n_;

  /** Flag used by the dfs_bidir traversal to determine whether the
      current node has already been descended into. */
  bool descended_;

  /** The ordering state determines the storage order. */
  T_ordering order_; 

  /// The cell stack. Entry i is the cell on level i of the current
  /// path. push() may be called on level maxlevel_-1 and then writes
  /// entry maxlevel_, so the stack needs maxlevel_+1 entries.
  T_cell* cell_[maxlevel_+1];

#ifdef HAVE_BOOST_SERIALIZATION
  friend class boost::serialization::access;
  /// Serialization support.
  /** Only the cell code is serialized; none of the other members
      (stack, position, ordering) are written or read here. */
  template<class T_arch>
  void serialize(T_arch& ar, const unsigned int version) {
    ar & code_;
  }
#endif

  /// Zeroes the entire cell stack.
  void clear_stack() {
    for(int i=0; i<=maxlevel_; ++i) cell_[i]=0;
  }

  /** Descends into the specified octant in the current cell. Adds
      a digit to the 3D position and cell code and pushes the
      order.
      \param oct Octant to descend into; must be in the range 0-7. */
  void push(uint8_t oct) {
    assert(!is_leaf());
    assert(n() < maxlevel_);
    assert(oct<8);
    // the position bits must be looked up *before* the order is
    // pushed, because they are defined by the ordering of this level.
    pos_.add_right(order_.pbits(oct));
    order_.push(oct);
    code_.add_right(oct);
    cell_[n()+1] = cell()->sub_grid()->get_cell(oct);
    ++n_;
  }

  /** Ascends into the parent of the current node. Removes a digit
      from the pos and code and pops the order. */
  void pop() {
    assert(n()!=0);
    pos_>>=1;
    order_.pop();
    code_.truncate();
    // zero out unused part of stack for debugging
    cell_[n()]=0;
    --n_;
  }

  /** Moves to the specified octant in the current grid. This is a pop
      followed by a push, but we code it separately to avoid touching
      the parent cell, since we can just offset the pointer compared
      to the oct we are in now. */
  void move(uint8_t oct) {
    assert(n()>0);

    // must be done first, while current_oct() still refers to the
    // octant we are leaving.
    cell_[n()] += oct-current_oct();

    pos_>>=1;
    order_.pop();
    code_.truncate();

    pos_.add_right(order_.pbits(oct));
    order_.push(oct);
    code_.add_right(oct);
  }

  /** Increments to the next octant in the current node. If there are
      no more octants, we pop until we find one. */
  void increment() {
    if(n()==0) {
      // incrementing at top level puts us in end.
      set_end();
    }
    else if(current_oct() <7)
      move(current_oct()+1);
    else {
      // pop and try to increment again.
      pop();
      increment();
    }
  }

  /** Put traversal in end state: the level becomes 0xff, the order
      is reset, and the code and position are truncated to the top
      level and then incremented. */
  void set_end() {
    n_ = 0xffU;
    order_.reset();
    code_.truncate(0); ++code_;
    pos_ >>= pos_.level(); ++pos_;
  }

  /// The current level.
  uint8_t n() const { return n_; }

  /// The octant of the current cell within its parent, which is the
  /// lowest three bits of the cell code.
  uint8_t current_oct() const { return code_.code()&7; }


public:
  /// Default constructor constructs an invalid tracker with no grid.
  cell_tracker() : 
    task_(mpi_rank()), g_(0), n_(0), descended_(false) {
    clear_stack();
    set_end(); 
  }

  /** Creates a cell tracker to the root cell of a grid. */
  cell_tracker(T_topgrid* g) : 
    task_(mpi_rank()), g_(g), n_(0), descended_(false) {
    clear_stack();
    restart();
  }

  /** Creates a cell tracker to the root cell of a grid. */
  cell_tracker(T_topgrid& g) : 
    task_(mpi_rank()), g_(&g), n_(0), descended_(false) {
    clear_stack();
    restart();
  }

  /** Copy constructor only copies the necessary elements of the
      stack, i.e. entries 0..n_, and nothing if rhs is at end. The
      remaining entries are left unset. */
  cell_tracker(const cell_tracker& rhs) :
    task_(mpi_rank()), code_(rhs.code_), pos_(rhs.pos_), g_(rhs.g_), 
    n_(rhs.n_), descended_(rhs.descended_), order_(rhs.order_) {
    if(n_!=0xff) 
      for(int i=0; i<=n_; ++i) cell_[i] = rhs.cell_[i];
  }

  /** Does one step in a depth-first traversal of the tree. This
      visits each cell exactly once, so skips previously visited cells
      when ascending the tree. */
  void dfs() {
    assert(!at_end());
    if(is_leaf()) {
      // if we are at a leaf, we try to increment. this may pop.
      increment();
    }
    else {
      // otherwise we descend to octant 0 in the subtree
      push(0);
    }
  }


  /** Does one step in a depth-first traversal of the tree, stopping
      also on the ascent. This visits all non-leaf cells twice and
      leaf cells once. */
  void dfs_bidir() {
    assert(!at_end());
    if(is_leaf() || descended_) {
      if(n()==0)
        set_end();
      else if(current_oct() <7) {
        // next sibling, which has not been descended into yet
        move(current_oct()+1);
        descended_=false;
      }
      else {
        // last octant: pop, and remember that the parent we land on
        // has already been descended into.
        pop();
        descended_=true;
      }
    }
    else {
      // for non-leaf cells, descend
      push(0);
      descended_=false;
    }
  }


  /// Trackers compare equal if their cell codes are equal. No other
  /// member takes part in the comparison.
  bool operator==(const cell_tracker& rhs) const {return code_==rhs.code_;}
  bool operator!=(const cell_tracker& rhs) const {return code_!=rhs.code_;}

  // Iterator functionality. When used with the increment operators,
  // the tracker traverses only leaf cells. Note that due to the need
  // to make a copy for the postfix operator, it is quite expensive.

  T_cell& operator*() const { return *cell(); }
  T_cell* operator->() const { return cell(); }
  /** Increment operator only stops at leaf cells on this task. */
  cell_tracker& operator++ () { // prefix
    dfs(); 
    while(!at_end() && (!is_leaf() || task()!=task_)) dfs(); 
    return *this; }
  cell_tracker operator++ (int){ // postfix
    cell_tracker i (*this); operator++ (); return i;}

  /// Set tracker to invalid.
  void reset () { set_end(); }

  /// Restart the traversal, setting the tracker to the root node.
  /// Requires that the tracker has a grid.
  void restart() {
    assert(g_);
    n_=0; cell_[n()] = &g_->c_;
    descended_=false;
    order_.reset();
    code_=T_code();
    pos_=T_qpoint();
  }

  /// Restart the traversal on grid g, setting the tracker to its
  /// root node.
  void restart(T_topgrid* g) {
    g_=g; restart(); }

  /// True if the tracker is in the end (invalid) state.
  bool at_end() const { return n()==0xffU; }

  T_topgrid* grid() const { return g_; }

  /// The task of the current cell, or -1 if the tracker is at end.
  int task() const {return at_end() ? -1 : cell()->task(); }

  const T_qpoint& qpos() const { return pos_; }
  T_code code() const { return code_; }

  /// Nop for this class.
  void reset_intersections() const {}

  /// Converts to a bool indicating whether the tracker points to a
  /// valid cell that resides on this task.
  operator bool() const {
    return !at_end() && (cell_[n()]->task()==task_);}

  /// The current cell. The tracker must not be at end.
  T_cell* cell() const {assert(!at_end()); return cell_[n()];}

  /** Finds the ray intersection with the next cell boundary. Note
      that the position of the ray along its trajectory is completely
      irrelevant to this function, because it only looks at the
      intersections out of this cell.

      The calculation of the column density should really be pulled
      out of this, the tracker should not need to know anything about
      what's in the cells.

      On return the tracker points to the next leaf cell along the
      ray, or is at end if the ray left the grid or max_len was
      reached.

      \param r The ray.
      \param max_len Length at which the ray is stopped; compared
      against r.length() plus the step through this cell.
      \param dn Output: the returned length times the densities of
      the absorber in the cell that was traversed.
      \param hit_max Output: true if max_len was reached.
      \return The length of the step.
  */
  T_float
  intersection_from_within (const ray_base& r, T_float max_len, 
                            T_densities& dn, bool& hit_max) {
    assert(!at_end());
    assert(is_leaf());

    // save cell pointer since we will update ourselves below
    const T_cell* c=cell();

    // to get to the next cell, we call "advance" and then as many
    // push as necessary to get to a leaf cell.
    T_float len = raycast_dfs_advance(r);
    while(!at_end() && !is_leaf()) raycast_dfs_push(r);

    // if we hit max length, invalidate tracker
    if(r.length()+len>=max_len) {
      len=max_len-r.length();
      hit_max=true;
      reset();
    }
    else {
      hit_max=false;
    }

    dn=len*c->data()->get_absorber().densities();

    assert(at_end() || (c!=cell()));
    return len;
  }

  /** Finds the fractional distance along a ray that corresponds to a
      certain column density through the cell. Since the cells are
      uniform density, this is trivial for this grid: the fraction fn
      is returned unchanged and the other arguments are ignored.

      This should also be taken out of the tracker.
  */
  T_float column_to_distance_fraction(T_float fn,
                                      const vec3d& pos, 
                                      const vec3d& dir,
                                      T_float dl) const {
    return fn;
  }

  /// True if the current cell is a leaf. The tracker must not be at end.
  bool is_leaf() const { return cell()->is_leaf(); }

  // Geometry of the current cell. The size is the grid size divided
  // by 2^level, and the minimum corner is the real-space point of
  // the cell position.

  vec3d getsize() const {
    return g_->getsize()/(1<<n()); }

  T_float getsize(int dim) const {
    return g_->getsize(dim)/(1<<n()); }

  vec3d getinvsize() const {
    return (1<<n())/g_->getsize(); }

  vec3d getmin() const { 
    return g_->real_point(qpos()); }

  T_float getmin(int dim) const { 
    return g_->real_point(qpos(), dim); }

  vec3d getmax() const { 
    return g_->real_point(qpos())+getsize(); }
 
  T_float getmax(int dim) const { 
    return g_->real_point(qpos(),dim)+getsize(dim); }
 
  vec3d getcenter() const {
    return g_->real_point(qpos())+0.5*getsize(); }

  /// True if p is in the grid and the current cell position contains
  /// the qpoint of p.
  bool contains(const vec3d& p) const {
    if(!g_->in_grid(p)) return false;
    else return qpos().contains(g_->qpoint(p)); }

  void assert_contains(const vec3d& p) const {
    assert(contains(p)); }

  /// Volume of the current cell: the grid volume divided by 8^level.
  T_float volume() const {
    // Need uint64_t here otherwise we will overflow a 32-bit int for
    // levels>10.
    return product(g_->getsize())/(uint64_t(1)<<(3*n())); }

  typename T_cell::T_data* data() const {
    return cell()->data(); }

  /// Prints the cell pointer, task and position. The tracker must
  /// not be at end.
  std::ostream& print(std::ostream& os) const {
    os << "Tracker: " << cell() << ", " << task() << ", " << qpos();
    return os;
  }

  void locate(T_qpoint qp, bool accept_outside);
  void locate(T_code c);

  void raycast_dfs_push(const ray_base& r);
  T_float raycast_dfs_advance(const ray_base& r);

};


/** Descends from the current cell to the cell containing qp.

    When accept_outside is true, qp is first clamped to the extent of
    the current cell (at the resolution of qp), so the descent always
    proceeds. When it is false, a qp outside the current cell resets
    the tracker. A qp whose level is above the current level always
    resets the tracker.

    The descent stops when a leaf is reached or when the level of qp
    is reached, whichever comes first.  */
template<typename T> 
void
mcrx::cell_tracker<T>::locate(T_qpoint qp,
                              bool accept_outside)
{
  // A point coarser than the current cell cannot be found below it.
  if(qp.level()<n()) {
    reset();
    return;
  }

  if(!accept_outside) {
    // the point has to be inside the cell we start from
    if(!qpos().contains(qp)) {
      reset();
      return;
    }
  }
  else {
    // Express the low and high corners of the current cell at the
    // resolution of qp, then clamp qp between them.
    T_qpoint lo_corner(qpos());
    lo_corner.extend_low(qp.level());
    T_qpoint hi_corner(qpos());
    hi_corner.extend_high(qp.level());
    assert(lo_corner.level()==qp.level());
    assert(hi_corner.level()==qp.level());

    // working on the bare integer positions is fine here because all
    // three points are on the same level
    qp = T_qpoint(truncate_to_box(qp.pos(), 
                                  lo_corner.pos(), 
                                  hi_corner.pos()),
                  qp.level());
    assert(qpos().contains(qp));
  }

  for(;;) {
    if(is_leaf() || n()>=qp.level()) break;
    // the ordering tells us which octant the next digit of qp is in
    const uint8_t next_oct = order_.octant(qp.extract_level(n()+1));
    push(next_oct);
  }
}


/** Descends from the current cell to the cell with code c. The
    tracker is reset if c is on a level above the current one or if
    the current code does not contain c. The descent stops at the
    level of c, or earlier if a leaf is reached. */
template<typename T> 
void
mcrx::cell_tracker<T>::locate(T_code c)
{
  // Unreachable targets: a code above our level, or one that is not
  // inside the current cell.
  const bool above = c.level()<n();
  if(above || !code().contains(c)) {
    reset();
    return;
  }

  for(;;) {
    if(is_leaf() || n()>=c.level()) break;
    // the digit of the code on the next level is the octant itself
    push(c.extract_level(n()+1));
  }
}


/** The push dfs raycasting operation from Laine & Karras. Descends
    into the child of the current (non-leaf) cell that the ray enters
    first. */
template<typename T> 
void
mcrx::cell_tracker<T>::raycast_dfs_push(const ray_base& r) 
{
  assert(!is_leaf());

  // extent of the current cell
  vec3d lo,hi;
#pragma forceinline recursive
  lo = getmin();
#pragma forceinline recursive
  hi = getmax();

  // t-values at which the ray crosses the center planes of the
  // current cell, which are the planes splitting it into subcells.
  // The values for axes the ray is parallel to are not used below.
  vec3d t_split;
#pragma forceinline recursive
  t_split = r.inverse_direction()*(0.5*(lo+hi) - r.position());

  // t-value at which the ray enters the current cell: the largest of
  // the per-axis entry values. Axes the ray is parallel to contribute
  // the most negative value and so never determine the result.
  T_float t_entry = 0;
  for(int i=0; i<3; ++i) {
    T_float t_axis = blitz::neghuge(T_float());
    if(r.direction()[i]!=0)
      t_axis = r.inverse_direction()[i]*
        ((r.direction()[i]>0 ? lo[i] : hi[i]) - r.position()[i]);
    t_entry = (i==0) ? t_axis : std::max(t_entry, t_axis);
  }

  // Select the half of the cell the ray enters, per axis. For a ray
  // moving in the positive direction that is the low half if it
  // enters before crossing the split plane; for a ray moving in the
  // negative direction the low half is entered only if the split
  // plane has already been crossed. For axes the ray is parallel to
  // we compare the ray position to the center directly.
  T_qpoint::T_pbits pbits;
#pragma ivdep
  for(int i=0; i<3; ++i) {
    bool low_half;
    if(r.direction()[i] == 0)
      low_half = r.position()[i]<0.5*(lo[i]+hi[i]);
    else if(r.direction()[i]>0)
      low_half = t_entry < t_split[i];
    else
      low_half = t_entry > t_split[i];
    pbits[i] = low_half ? 0:1;
  }

  const uint8_t oct = order_.octant(pbits);
  DEBUG(3,std::cout << "Raycast push in " << code() << " " << qpos() << " to oct " << int(oct) << std::endl;);
  push(oct);
  DEBUG(3,std::cout << "\tnew: " << code() << " " << qpos() << std::endl << "\t" << getmin() << " - " << getmax() << std::endl;);

}

/** The advance/pop dfs raycasting operation from Laine &
    Karras. Advances to the next sibling of the first (grand)parent
    that the ray does not exit. If the ray exits the entire grid, the
    tracker is reset to the end state.

    The tracker must be at a leaf cell.

    \param r The ray.
    \return The smallest of the t-values at which the ray crosses the
    exit faces of the cell the tracker was in on entry. */
template<typename T> 
mcrx::T_float
mcrx::cell_tracker<T>::raycast_dfs_advance(const ray_base& r) 
{
  assert(is_leaf());

  // calculate t-values for the exit point from the current cell.
  vec3d x1, t1;
  T_float tmin=blitz::huge(T_float());
  for(int i=0; i<3; ++i) {
#pragma ivdep
    x1[i] = (r.direction()[i]<0) ? getmin(i) : getmax(i);
#pragma ivdep
    t1[i] = r.inverse_direction()[i]*(x1[i] - r.position()[i]);
    tmin = std::min(tmin, t1[i]);
  }

  // we update the position by adding +-1 (depending on the ray
  // direction) to the dimension(s) where t1==tmin. This is safe for
  // rays parallel to an axis because t1==inf for those axes.
  //
  // The step is done in integer arithmetic. Stepping through a
  // floating-point value would require converting -1.0 to the
  // coordinate type when leaving the grid on the low side, which is
  // undefined for an unsigned type.
  T_qpoint::T_pos pnext(qpos().pos());
  T_qpoint::T_pos delta(pnext);
  T_qpoint::T_pos::T_numtype maxdelta=0;

  for(int i=0; i<3; ++i) {
    if(t1[i]==tmin) {
      if(r.direction()[i]<0)
        --pnext[i];
      else
        ++pnext[i];
    }
    delta[i] ^= pnext[i];
    maxdelta = std::max(maxdelta, delta[i]);
  }

  // find the msb bit they differ on. that's the number of levels we
  // need to pop. This is floor(log2(maxdelta)), computed with shifts
  // so that it is also defined (as 0) when maxdelta is 0.
  uint8_t poplevels = 0;
  for(T_qpoint::T_pos::T_numtype rest=maxdelta; rest>1; rest>>=1)
    ++poplevels;

  DEBUG(3,std::cout << "Raycast advance in " << code() << " " << qpos() << " to pnext " << T_qpoint(pnext,qpos().level()) << " poplevels " << int(poplevels) << std::endl;);

  if(poplevels<n()) {
    // the pbits of the next position will be the appropriate bits of
    // pnext. shift it out now before we blow away poplevels.
#pragma forceinline recursive
    pnext>>=poplevels;
    for(; poplevels>0; --poplevels) pop();
    
    // now we move to the next octant, but that is determined by the
    // ordering in the level *above*. Instead of calling pop/push
    // here, we code it explicitly to avoid unnecessary work
    order_.pop();
    pos_>>=1;

    const uint8_t oct = order_.octant(pnext&1);

    // get current oct *before* changing code
    cell_[n()] += oct-current_oct();

    code_.truncate();
    pos_.add_right(order_.pbits(oct));
    order_.push(oct);
    code_.add_right(oct);
  }
  else {
    // if we need to pop the entire stack, that means we exited the entire grid
    reset();
  }

  DEBUG(3,std::cout << "\tnew: " << code() << " " << qpos() << std::endl;);
  // the cell bounds are only defined while we point to a cell
  DEBUG(3,if(!at_end()) std::cout << "\t" << getmin() << " - " << getmax() << std::endl;);
  
  return tmin;
}


